Autoencoder-based representation learning of SARS-CoV-2 RNA genome sequences

Author: Dong Liang

E-mail: ldifer@gmail.com


Background

The novel coronavirus disease (COVID-19), which emerged in late 2019, has developed into a global pandemic, posing an immediate and ongoing threat to the health and economic activities of billions of people. The severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), which causes COVID-19, is characterized by rapid and efficient individual-to-individual transmission and a range of clinical courses, including severe acute respiratory distress syndrome, viral pneumonia, mild upper respiratory infections (URIs) and asymptomatic carriage [1]. Covariates associated with worse outcomes include hypertension, diabetes, coronary heart disease and older age [1]. A study of COVID-19 cases on the Diamond Princess cruise ship in Japan estimated the proportion of asymptomatic patients to be 17.9% (95% CrI: 15.5-20.2%) [2]. All of these present great challenges for the prevention and control of COVID-19 transmission.

There is clear evidence that SARS-CoV-2 is evolving rapidly. A recent phylogenetic network analysis of 160 SARS-CoV-2 genomes identified three central variants based on amino acid changes [3]. Tang et al. found that two SNPs in strong linkage disequilibrium at positions 8,782 (orf1ab: T8517C, synonymous) and 28,144 (ORF8: C251T, S84L) form haplotypes that classify SARS-CoV-2 viruses into two major lineages (L and S types) [4]. Mutations also frequently occur in the receptor-binding domain (RBD) of the spike protein, which mediates infection of human cells [5]. A recent analysis of the viral genomes of 6,000 infected people identified one mutation in the spike protein (named D614G) associated with increased virus transmissibility [6]. Clearly, the dynamic evolution of the viral genome can have important effects on the spread, pathogenesis and immune intervention of SARS-CoV-2.

Machine learning methods have been successfully applied to classify different types of cancer and to identify potentially valuable disease biomarkers [7-14]. In addition, convolutional neural networks (CNNs) have become the method of choice for medical image recognition and classification. Their convolution and pooling architectures and parameter-sharing mechanism make them computationally more efficient than traditional fully connected neural networks. Despite their great popularity in computer vision tasks, CNNs are less commonly employed in genome sequence analysis. This study uses a CNN-based autoencoder to perform representation learning on 3,161 full-length RNA genome sequences of SARS-CoV-2 collected from across various U.S. states and the world. The model prototype developed in this study could serve as a first step toward a disease risk scoring system.

In [58]:
# from `covid19_util.seq_util import *
# from covid19_util.web_util import *
from covid19_toolkit import *
from covid19_toolkit.seq_util import *
from covid19_toolkit.model_util import *
from covid19_toolkit.model_util2 import *

# Seq processing
import pysam
import vcf
from Bio import SeqIO
from Bio.Seq import Seq

# Data processing
import numpy as np
import pandas as pd

# Plot
from plotnine import *
import pickle
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import style
style.use('ggplot')


%load_ext autoreload
%autoreload 2

Virus Information

In [201]:
virus_info = load_virus_info('nucleotide')
In [205]:
len(virus_info.Country_Region.unique())
Out[205]:
37
In [67]:
virus_info['City'].unique()
Out[67]:
array([nan, 'VADODARA', 'KODINAR', 'BOTAD', 'UNA', 'JAMNAGAR', 'DAHOD',
       'JUNAGADH', 'RAJKOT', 'AHMEDABAD', 'MI', 'WI', 'IL', 'CA',
       'KARACHI', 'MODASA', 'HIMATNAGAR', 'GANDHINAGAR', 'DAHEGAM',
       'MANSA', 'BAVARIA', 'PRANTIJ', 'AK', 'FL', 'VA', 'PA', 'MD', 'VT',
       'PR', 'IA', 'DC', 'DHANSURA', 'ATHENS', 'NOVI PAZAR', 'SC', 'WA',
       'CT', 'ID', 'HYDERABAD', 'VICTORIA', 'SURAT', 'NY',
       'GUANGDONG, GUANGZHOU', 'LA', 'OR', 'MILHEEZE', 'NJ', 'GA', 'HI',
       'IN', 'MN', 'OH', 'RI', 'NV', 'NC', 'MARICOPA COUNTY, ARIZONA',
       'UT', 'KWAZULU-NATAL', 'AZ', 'KS', 'MA', 'MO', 'NE', 'NH', 'TX',
       'ANHUI, FUYANG', 'WUHAN', 'BEIJING', 'KPK', 'ANTIOQUIA',
       'HUBEI, WUHAN', 'ZHEJIANG, HANGZHOU', 'GILGIT', 'VALENCIA',
       'HO CHI MINH CITY', 'SHANGHAI', 'KERALA STATE', 'GUANGZHOU',
       'YUNNAN', 'HANGZHOU', 'SHENZHEN'], dtype=object)
In [3]:
virus_info.Length.hist()
plt.title('Length of genomic sequences (SARS-Cov-2)')
Out[3]:
Text(0.5, 1.0, 'Length of genomic sequences (SARS-Cov-2)')

Length

In [10]:
top10 = virus_info.groupby('Country_Region').agg(len)['Length'].sort_values(ascending=False).head(10).index.tolist()
virus_info_R = virus_info.loc[virus_info['Country_Region'].isin(top10), :]
In [12]:
(
    ggplot(aes(x='Country_Region', y='Length'), data=virus_info_R) + 
    geom_boxplot(alpha = 0.5) +
    geom_jitter(alpha=0.2) +
    theme(
#       figure_size=(10,6),
      legend_key=element_blank(),
      axis_text_x = element_text(rotation=45, hjust=1.),
    )
)
Out[12]:
<ggplot: (-9223363244693739979)>

Length distribution by country/region

In [73]:
g = (
    ggplot(aes(x='Length'), data=virus_info_R) +
    geom_histogram(aes(fill = 'Country_Region'), alpha = 0.9, bins = 40)+
    facet_wrap('~Country_Region', scales = "free_y") + 
    theme(
      figure_size=(10,6),
      legend_key=element_blank(),
      axis_text_x = element_text(rotation=45, hjust=1.),
    )
)

g + ggtitle('Length distribution of SARS-CoV-2 in top 10 countries/regions')
Out[73]:
<ggplot: (-9223363262639078456)>

Length by collection date

In [75]:
(
    ggplot(aes(x='Collection_Date', y='Length'), data=virus_info_R) + 
    geom_point(aes(color = 'Country_Region')) +
    stat_smooth(method='loess') + 
    scale_x_date(date_breaks = "1 month", date_labels =  "%b %Y") +
    ylab('SARS-CoV-2 length (bp)') + 
    facet_wrap('~Country_Region') +
    theme(
      figure_size=(6, 4),
      legend_key=element_blank(),
      axis_text_x = element_text(rotation=45, hjust=1.),
    )
)
Out[75]:
<ggplot: (-9223363262580141880)>
In [ ]:
virus_info_R.groupby('Country_Region').agg(['mean', 'std']).loc[:, 'Length']
Out[ ]:
mean std
Country_Region
AUSTRALIA 29809.096685 13.460892
CHINA 29858.256410 51.889602
FRANCE 29902.740741 2.333333
GERMANY 29839.700000 51.214821
GREECE 29820.835052 14.918306
HONG KONG 29856.200000 55.161581
INDIA 29806.023438 25.190757
TAIWAN 29877.095238 82.495397
THAILAND 29815.388889 37.843546
USA 29840.069302 60.416249

Geographic location

In [68]:
count_by_country = pd.DataFrame(virus_info.groupby('Country_Region').agg(len)['Length'].sort_values(ascending=False).head(10))
count_by_country.reset_index(inplace=True)
count_by_country.columns = ['Country_Region', 'Count']
In [207]:
country_list = count_by_country.Country_Region[::-1]

(
    ggplot(aes(x='Country_Region', y='Count'), data=count_by_country) + 
    geom_bar(stat = 'identity', size=10) +
    scale_x_discrete(limits=country_list) +
    coord_flip()+
    theme(
#       figure_size=(10,6),
      legend_key=element_blank(),
      axis_text_x = element_text(rotation=45, hjust=1.),
    )
)
Out[207]:
<ggplot: (-9223363248837482748)>
In [71]:
plt.subplot(121)
virus_info.groupby('Country_Region').agg(len)['Length'].sort_values(ascending=False).head(10).plot.bar()
plt.subplot(122)
virus_info.groupby('City').agg(len)['Length'].sort_values(ascending=False).head(10).plot.bar()
plt.xlabel('City/US state')
plt.gcf().set_size_inches(10, 4)

SARS-CoV-2 sequence data

Full length RNA genome

In [101]:
coding_seqs = load_virus_seq('./nucleotide.fasta')
covid19_seqs = [COVID19(id, seq) for id, seq in coding_seqs.items()]

covid19_complete = {seq.id: seq(ORF = 'complete', return_residual = False) for seq in covid19_seqs }
covid19_dataset = zero_padding(covid19_complete)
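The `zero_padding` call above presumably pads each one-hot matrix with zero rows to a common length, since the genomes vary slightly in length but the CNN requires a fixed input shape. A minimal sketch under that assumption (`zero_pad` is a hypothetical stand-in, not the toolkit's implementation):

```python
import numpy as np

def zero_pad(seq_dict, target_len=None):
    """Pad each (L, 5) one-hot matrix with zero rows to a common length."""
    if target_len is None:
        # Default to the longest sequence in the batch
        target_len = max(m.shape[0] for m in seq_dict.values())
    padded = {}
    for acc, m in seq_dict.items():
        pad_rows = target_len - m.shape[0]
        padded[acc] = np.vstack([m, np.zeros((pad_rows, m.shape[1]), dtype=m.dtype)])
    return padded
```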

Spike

In [ ]:
spike = {seq.id: seq(ORF = 'spike', return_residual = True) for seq in covid19_seqs}
spike['MT509460']

Reference genome

In [193]:
# Reference seq
coding_seqs['NC_045512']['Severe acute respiratory syndrome coronavirus 2 isolate Wuhan-Hu-1'][:500]
Out[193]:
'ATTAAAGGTTTATACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTCGGCTGCATGCTTAGTGCACTCACGCAGTATAATTAATAACTAATTACTGTCGTTGACAGGACACGAGTAACTCGTCTATCTTCTGCAGGCTGCTTACGGTTTCGTCCGTGTTGCAGCCGATCATCAGCACATCTAGGTTTCGTCCGGGTGTGACCGAAAGGTAAGATGGAGAGCCTTGTCCCTGGTTTCAACGAGAAAACACACGTCCAACTCAGTTTGCCTGTTTTACAGGTTCGCGACGTGCTCGTACGTGGCTTTGGAGACTCCGTGGAGGAGGTCTTATCAGAGGCACGTCAACATCTTAAAGATGGCACTTGTGGCTTAGTAGAAGTTGAAAAAGGCGTTTTGCCTCAACTTGAACAGCCCTATGTGTTCATCAAACGTTCGGATGCTCGAACTG'
In [195]:
seq_NC = covid19_complete['NC_045512']
seq_NC.shape
Out[195]:
(29903, 5)
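The (29903, 5) shape above reflects one-hot encoding of the 29,903-nt reference genome over five channels. A minimal sketch of such an encoding, assuming an A/C/G/T/N channel ordering (the actual mapping used by the `covid19_toolkit` helpers may differ):

```python
import numpy as np

# Hypothetical base-to-channel map; channel 4 absorbs ambiguous calls (N, etc.)
BASES = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}

def one_hot_encode(seq):
    """Encode an RNA/DNA string as a (len(seq), 5) float32 matrix."""
    encoded = np.zeros((len(seq), 5), dtype=np.float32)
    for i, base in enumerate(seq.upper()):
        encoded[i, BASES.get(base, 4)] = 1.0  # unknown bases go to the N channel
    return encoded
```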

Autoencoder-based representation learning

Data Preprocessing

In [3]:
# Data loading
complete_seqs = load_virus_seq('./nucleotide.fasta')
covid19_dataset = preprocessing(complete_seqs)
In [8]:
# Minibatching - custom training 
ids = [id for id, seq in covid19_dataset.items()]
seq = [tf.cast(seq, tf.float32) for id, seq in covid19_dataset.items()]
train_dataset = tf.data.Dataset.from_tensor_slices((seq, seq)).cache().shuffle(7000).batch(32)
In [4]:
# Training data - keras training
ids = [id for id, seq in covid19_dataset.items()]

ids_march = np.array(ids)[pd.Series(ids).isin(virus_info[virus_info['Collection_Date'].dt.month == 3]['Accession'])]

seq = np.array([seq.astype(np.float32) for id, seq in covid19_dataset.items() if id in ids_march])
# seq.shape

seq_all = np.array([seq.astype(np.float32) for id, seq in covid19_dataset.items()])

Train a CNN autoencoder

The autoencoder was built on a 2D convolutional neural network; architectures combining max-pooling, dropout and convolutional layers in various ways were tried. The best performance was achieved with the simple settings shown below. Notably, the 2D convolutional architecture did yield superior classification performance compared to a 1D CNN.

  • Filter size: 5 x 5
  • No. of channel: 1, 32, 64, 64, 32, 1
  • batch_size: 32
  • Activation: ReLU
  • Convolutional/deconvolutional layers + Sigmoid layer
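Under the settings listed above, the encoder/decoder can be sketched in Keras as follows. This is a hypothetical reconstruction for illustration (`build_autoencoder` is not the notebook's `Autoencoder_cnn3`); in particular, the latent projection via global pooling and a dense layer is an assumption:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

def build_autoencoder(seq_len=30018, n_bases=5, latent_dim=3):
    # Encoder: 2D convolutions over the one-hot sequence "image"
    inp = layers.Input(shape=(seq_len, n_bases, 1))
    x = layers.Conv2D(32, 5, strides=1, padding='same', activation='relu')(inp)
    x = layers.Conv2D(64, 5, strides=1, padding='same', activation='relu')(x)
    x = layers.GlobalAveragePooling2D()(x)
    z = layers.Dense(latent_dim, name='latent')(x)

    # Decoder: project back, then deconvolve to the input shape
    y = layers.Dense(seq_len * n_bases, activation='relu')(z)
    y = layers.Reshape((seq_len, n_bases, 1))(y)
    y = layers.Conv2DTranspose(64, 5, strides=1, padding='same', activation='relu')(y)
    y = layers.Conv2DTranspose(32, 5, strides=1, padding='same', activation='relu')(y)
    out = layers.Conv2DTranspose(1, 5, strides=1, padding='same', activation='sigmoid')(y)

    autoencoder = Model(inp, out)
    encoder = Model(inp, z)  # reused later for inference on the latent space
    return autoencoder, encoder
```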
In [2]:
def train(model, epoch, loss_fn, train_dataset,
          optimizer_fn, learning_rate, print_every = 10,
          manager = None, **kwargs):

    # Optimizer and training metric
    optimizer = optimizer_fn(learning_rate = learning_rate)
    total_loss_train = tf.keras.metrics.Mean(name = 'total_loss_train')

    # Initialize metrics
    total_loss_train.reset_states()

    for k in range(epoch):
        for x, y in train_dataset:
            with tf.GradientTape() as tape:
                # Model prediction
                yhat = model(x)

                # Compute loss
                total_loss = loss_fn(y, yhat)
                loss_mean = tf.reduce_mean(total_loss)  # Average batch loss

                # Record metrics
                total_loss_train.update_state(total_loss)

            # Calculate gradients of weights and biases
            grad = tape.gradient(loss_mean, model.trainable_weights)

            # Apply gradients
            optimizer.apply_gradients(zip(grad, model.trainable_weights))

        if k % print_every == 0 or k == epoch - 1:
            print('Epoch', k + 1)
            if 'valid_dataset' not in kwargs:
                print(f"Total_loss_train: {total_loss_train.result():.4f}")
            else:
                # Initialize validation metrics
                total_loss_val = tf.keras.metrics.Mean(name = 'total_loss_val')
                accuracy_val = tf.keras.metrics.BinaryAccuracy(name = 'accuracy_val')

                # Calculate validation metrics
                for x, y in kwargs['valid_dataset']:
                    yhat = model(x)
                    loss = loss_fn(y, yhat)
                    total_loss_val.update_state(tf.reduce_mean(loss))
                    accuracy_val.update_state(y, yhat)

                print(
                      f"Train loss: {total_loss_train.result():.4f}",
                      f"Val loss: {total_loss_val.result():.4f}",
                      f"Val Acc: {accuracy_val.result() * 100:.2f}",
                      )

    return model
In [1]:
# Define parameters for train_model function
# loss_fn = tf.nn.sigmoid_cross_entropy_with_logits

# CNN autoencoder parameters
params = {
    'conv_filters': [32, 64], 
    'conv_kernel_size': [5, 5], 
    'conv_stride': [1, 1],
    'convTranspose_filters': [64, 32, 1], 
    'convTranspose_kernel_size': [5, 5, 5], 
    'convTranspose_stride': [1, 1, 1],
    'latent_dim': 4
}


# Hyperparameters
model = Autoencoder_cnn3(**params)
loss_fn =  tf.keras.losses.BinaryCrossentropy()
epoch = 1
optimizer_fn = tf.optimizers.Adam # tf.train.AdamOptimizer
learning_rate = 3e-4 # 3e-4 # 5e-3     


# Train dataset
train_dataset = train_dataset # , X_minibatch10, S_mag_minibatch10p


trained_model = train(model, epoch, loss_fn, train_dataset, \
                            optimizer_fn, learning_rate = learning_rate, \
                            print_every = 10, manager = None,
#                             valid_dataset = valid_dataset,
#                             X_validation=X_valid, y_validation=y_valid
                           )
Epoch 1
Total_loss_train: 1.03
Epoch 2
Total_loss_train: 0.43
Epoch 3
Total_loss_train: 0.43
Epoch 4
Total_loss_train: 0.43
Epoch 5
Total_loss_train: 0.42
Epoch 6
Total_loss_train: 0.42
Epoch 7
Total_loss_train: 0.42
Epoch 8
Total_loss_train: 0.42
Epoch 9
Total_loss_train: 0.42
Epoch 10
Total_loss_train: 0.42
In [8]:
params = {
    'conv_filters': [32, 64], 
    'conv_kernel_size': [5, 5], 
    'conv_stride': [1, 1],
    'convTranspose_filters': [64, 32, 1], 
    'convTranspose_kernel_size': [5, 5, 5], 
    'convTranspose_stride': [1, 1, 1],
    'latent_dim': 3
}


# Hyperparameters
model_keras = Autoencoder_cnn3(**params)
# model_keras(tf.reshape(seq[0], (1, 30018, 5)))
# model.summary()
model_keras.compile(optimizer = Adam(lr=0.01), loss = tf.keras.losses.BinaryCrossentropy(), metrics=['accuracy'])
model_keras.fit(seq, seq, epochs = 3, batch_size = 32)   # , batch_size = 128
Train on 2006 samples
Epoch 1/3
2006/2006 [==============================] - 1375s 685ms/sample - loss: 2.9883 - accuracy: 0.7951
Epoch 2/3
2006/2006 [==============================] - 1276s 636ms/sample - loss: 3.0660 - accuracy: 0.8012
Epoch 3/3
2006/2006 [==============================] - 1278s 637ms/sample - loss: 3.0660 - accuracy: 0.8012

Inference

In [11]:
RERUN = False

if RERUN:
    
    pred = np.empty((2868, 3))

    pred500 = model_keras.encoder(seq_all[:500])

    pred500_1000 = model_keras.encoder(seq_all[500:1000])

    pred1000_1500 = model_keras.encoder(seq_all[1000:1500])

    pred1500_2000 = model_keras.encoder(seq_all[1500:2000])

    pred12000_ = model_keras.encoder(seq_all[2000:])

    pred = np.concatenate([pred500, pred500_1000, pred1000_1500, pred1500_2000, pred12000_], axis = 0) 
    
    pickle.dump( pred, open( "prediction_all.pkl", "wb" ) )

else:
    pred = pickle.load(open( "prediction_all.pkl", "rb" ))
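The manual 500-sample slices above can be generalized so the chunk size appears only once; a sketch using `np.array_split` to run the encoder in memory-bounded chunks (`encode_in_chunks` is a hypothetical helper, not part of the toolkit):

```python
import numpy as np

def encode_in_chunks(encoder_fn, data, chunk_size=500):
    """Run the encoder on fixed-size chunks to bound memory use,
    then concatenate the embeddings back into one array."""
    n_chunks = int(np.ceil(len(data) / chunk_size))
    parts = [np.asarray(encoder_fn(chunk)) for chunk in np.array_split(data, n_chunks)]
    return np.concatenate(parts, axis=0)
```

With the model above, this would replace the five slice-and-concatenate lines: `pred = encode_in_chunks(model_keras.encoder, seq_all)`.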
In [74]:
pred_df = pd.DataFrame(np.c_[virus_info.values, pred], columns = list(virus_info.columns) + ['d1', 'd2', 'd3'])
pred_df.d1 = pred_df.d1.astype('float')
pred_df.d2 = pred_df.d2.astype('float')
pred_df.d3 = pred_df.d3.astype('float')
In [75]:
# By month
pred_USA = pred_df.query("Country_Region == 'USA'").copy()  # copy to avoid SettingWithCopyWarning
pred_USA['Collection_month'] = pred_USA.Collection_Date.dt.month
pred_USA['Collection_month'] = pred_USA['Collection_month'].replace({1: "1, Jan", 2: "2, Feb", 3: '3, Mar', 4: '4, Apr', 5: '5, May'})

# By state
gt = pred_USA.groupby('City').count()['Accession'] > 100
city_list = np.array(pred_USA.groupby('City').count()['Accession'].index[gt])
pred_USA_select = pred_USA.loc[pred_USA.City.isin(city_list), :]

# By location
pred_df_top10 = pred_df.loc[pred_df.Country_Region.isin(top10), :]

Visualization

3D clustering

In [32]:
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline
# %matplotlib widget

d = pred_df.loc[:, ['d1', 'd2', 'd3']].values.astype('float')

fig = plt.figure(figsize=(8,7))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(d[:, 0], d[:, 1], d[:, 2], c='r', marker='o')
Out[32]:
<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7fe239348710>

By location

In [72]:
ref = pred_df.loc[pred_df.Accession == 'NC_045512', :]
top4 = pred_df.loc[pred_df.Country_Region.isin(['CHINA', 'USA', 'AUSTRALIA', 'GERMANY']), :]

(
    ggplot(aes(x='d1', y='d2', color = 'Country_Region'), data=top4) +  # pred_df_top10
    geom_point() +
    geom_point(aes(x='d1', y='d2'), color = 'black', shape = '*', size= 10, alpha = 0.3, data = ref) + 
#     scale_y_discrete(minor_breaks=[]) + 
#     scale_x_discrete(limits=manufacturer_list) +
#     coord_flip()+
    xlab('Dimension 1') + 
    ylab('Dimension 2') + 
    facet_wrap('~Country_Region') + 
    theme(
      figure_size=(10,6),
      legend_key=element_blank(),
      legend_position = "top",
      legend_title = element_blank(),
      axis_text_x = element_text(rotation=45, hjust=1.)
      
    )
)
Out[72]:
<ggplot: (-9223363248736475325)>

By state

In [97]:
%matplotlib inline

# By state
gt = pred_USA.groupby('City').count()['Accession'] > 43
city_list = np.array(pred_USA.groupby('City').count()['Accession'].index[gt])
pred_USA_select = pred_USA.loc[pred_USA.City.isin(city_list), :]


(
    ggplot(aes(x='d1', y='d2', color = 'City'), data=pred_USA_select) + 
    geom_point() +
    xlab('Dimension 1') + 
    ylab('Dimension 2') +
    facet_wrap('~City') + 
    theme(
      figure_size=(10,6),
      legend_key=element_blank(),
      legend_position = "top",
      legend_title = element_blank(),
      axis_text_x = element_text(rotation=45, hjust=1.)
      
    )
)
Out[97]:
<ggplot: (-9223363248803779068)>

By month

In [99]:
%matplotlib inline

gt = pred_USA.groupby('City').count()['Accession'] > 100
city_list = np.array(pred_USA.groupby('City').count()['Accession'].index[gt])
pred_USA_select = pred_USA.loc[pred_USA.City.isin(city_list), :]


(
    ggplot(aes(x='d1', y='d2', z = 'd3', color = 'Collection_month'), data=pred_USA_select) + 
    geom_point() +
    facet_grid('City ~ Collection_month') + 
    theme(
      figure_size=(12,8),
      legend_key=element_blank(),
      legend_position = "top",
      legend_title = element_blank(),
      axis_text_x = element_text(rotation=45, hjust=1.)
      
    )
)
Out[99]:
<ggplot: (-9223363248736278440)>
In [ ]: